import argparse

import numpy as np
import numpy.linalg as LA
from sympy import Add, E, I, Symbol, im, re, sqrt
from sympy.functions.elementary.miscellaneous import cbrt


def g_CauchyK_num(S):
    """Return the Cauchy-kernel-smoothed transform of the values +/-S as a sympy expression.

    With N = len(S) and eta = sqrt(1/(2N)), the result is the expression in the
    sympy symbol ``z``

        (1/(2N)) * sum_j [ 1/(z + S_j - i*eta) + 1/(z - S_j - i*eta) ],

    i.e. a transform of the symmetrized empirical distribution (every value
    S_j together with -S_j) in which every denominator carries an extra
    ``-i*eta`` offset.

    Args:
        S: sequence of N real numbers (``main`` passes the singular values of
            the observation matrix).

    Returns:
        sympy expression in ``Symbol('z')``.

    Raises:
        ZeroDivisionError: if ``S`` is empty.
    """
    N = len(S)
    # Raises ZeroDivisionError for empty S (N == 0), before any symbolic work.
    eta = np.sqrt(1/(2*N))

    z = Symbol('z')
    shift = I*eta  # loop-invariant imaginary offset
    terms = []
    for s in S:
        terms.append(1/(z + s - shift))
        terms.append(1/(z - s - shift))
    # Build the sum in a single Add: accumulating with `ret += term` re-flattens
    # (and re-sorts) the whole growing sum on every step, quadratic in N.
    return Add(*terms)/(2*N)


def Estimator(S_s, gX, gS, SNR, a, c):
    """Compute RIE estimates for X, Y and T = X @ Y from the singular values of S.

    For every observed singular value ``s_i``, the smoothed transform ``gS`` is
    evaluated at ``zz = s_i - I*eta`` and plugged into closed-form shrinkage
    formulas for the three targets.

    Args:
        S_s: 1-D sequence of the observed singular values (length N).
        gX: sympy expression in the symbol ``z``; ``main`` passes the Stieltjes
            transform of the shifted semicircle law describing X.
        gS: sympy expression in ``z`` (``main`` passes ``g_CauchyK_num(S_s)``).
        SNR: signal-to-noise ratio; must be > 0 because the estimates divide by
            SNR and sqrt(SNR). It also selects the smoothing offset ``eta``.
        a: aspect ratio N/M; used as ``(1-a)/a``, so it must be non-zero.
        c: diagonal shift of X; enters the closed form for Y.

    Returns:
        Tuple ``(output_X, output_Y, output_T)`` of float arrays of length N:
        the estimated eigenvalues of X and singular values of Y and T, aligned
        with ``S_s``. Empty input yields three empty arrays.
    """
    N = len(S_s)

    output_X = np.zeros(N)
    output_Y = np.zeros(N)
    output_T = np.zeros(N)
    if N == 0:
        # Nothing to estimate; also avoids dividing by N when computing eta.
        return output_X, output_Y, output_T

    # Smoothing offset grows with SNR. Test the larger threshold first: with
    # `if SNR > 2: ... elif SNR > 4: ...` the 128 branch could never run.
    if SNR > 4:
        dfr = 128
    elif SNR > 2:
        dfr = 64
    else:
        dfr = 32
    eta = np.sqrt(dfr/(2*N))  # loop-invariant imaginary offset
    sqrt_snr = np.sqrt(SNR)

    z = Symbol('z')

    for i in range(N):
        s = S_s[i]
        zz = s - I*eta
        gS_eval = gS.subs(z, zz).evalf()

        #### optimal eigenvalue for X
        zeta = gS_eval + ((1-a)/a)*(1/zz)
        Z = (zz/zeta - 1)/SNR
        Est = gX.subs(z, sqrt(Z)).evalf() + gX.subs(z, -sqrt(Z)).evalf()
        output_X[i] = im(((Est/zeta)/(2*SNR*im(gS_eval))).evalf())

        ### optimal singular value for Y
        # Closed-form (Cardano-type, built from cube roots) expression. The
        # repeated sub-terms are named once; the formula itself is unchanged.
        P = -1 + a*(1 - gS_eval*zz + zz**2)
        B = -1 + a + a*zz**2
        disc = ((729*(c**2)*(gS_eval**2)*P**2)/((a**2)*zz**2)
                + (-3*c**2 + 6*(2 + gS_eval**2 - (gS_eval*B)/(a*zz)))**3)
        num = 3**(2/3)*(a*(-4 + c**2)*zz - 2*a*(gS_eval**2)*zz + 2*gS_eval*B)
        den = a*zz*cbrt((9*c*gS_eval*P)/(a*zz) + (1/3)*sqrt(disc))
        q4 = -3*c + num/den + cbrt((27*c*gS_eval*P)/(a*zz) + sqrt(disc))
        q4 = (q4/6).evalf()
        output_Y[i] = (im(q4)/(sqrt_snr*im(gS_eval))).evalf()

        ### optimal singular value for T
        H = re(gS_eval)
        output_T[i] = (s - ((1-a)/(a*s) + 2*H))/sqrt_snr

    return output_X, output_Y, output_T
    
def _relative_error(A, A_hat):
    """Return the relative squared Frobenius error ||A - A_hat||_F^2 / ||A||_F^2."""
    return LA.norm(A - A_hat)**2 / LA.norm(A)**2


def _save_and_print(values, matrix_name, c, SNR, method):
    """Save `values` as 'MF-<matrix_name>_c=<c>_SNR=<SNR>_<method>.npy' and print them."""
    filename = 'MF-' + matrix_name + '_c=' + str(c) + '_SNR=' + str(SNR) + '_' + method + '.npy'
    np.save(filename, values)
    print(values)


def main():
    """Monte-Carlo comparison of oracle and RIE estimators for a matrix product.

    Command-line options:
        -a        aspect ratio N/M, 0 < a <= 1 (so M = int(N/a) >= N).
        -s        signal-to-noise ratio SNR, > 0.
        -c        integer shift added to the diagonal of X.
        -N        number of rows N (default 2000).
        --trials  number of independent trials (default 10).
        --seed    optional seed for NumPy's global RNG (default: unseeded).

    Each trial draws X (N x N symmetric Gaussian, scaled by 1/sqrt(N), plus
    c*I) and Y, W (N x M Gaussian, scaled by 1/sqrt(N)). It observes
    S = sqrt(SNR) * X @ Y + W and records the relative squared Frobenius
    error of oracle and RIE estimates of X, Y and T = X @ Y, plus that of
    X_hat @ Y_hat versus T.

    Side effects: after every trial the cumulative error arrays are appended
    to 'SNR=<SNR>.txt'. At the end, each array is saved to an .npy file and
    printed.
    """
    z = Symbol('z')
    p = argparse.ArgumentParser(
        description='Oracle vs. RIE estimation of X, Y and T = X @ Y '
                    'from S = sqrt(SNR) * X @ Y + W.')
    p.add_argument('-a', type=float, required=True,
                   help='aspect ratio N/M, must satisfy 0 < a <= 1')
    p.add_argument('-s', type=float, required=True,
                   help='signal-to-noise ratio SNR, must be > 0')
    p.add_argument('-c', type=int, required=True,
                   help='shift c added to the diagonal of X')
    p.add_argument('-N', type=int, default=2000,
                   help='number of rows N (default: 2000)')
    p.add_argument('--trials', type=int, default=10,
                   help='number of independent trials (default: 10)')
    p.add_argument('--seed', type=int, default=None,
                   help="seed for NumPy's global RNG (default: unseeded)")
    args = p.parse_args()

    a = args.a
    c = args.c
    SNR = args.s
    N = args.N
    Ex = args.trials

    # Fail with a usage message instead of a traceback deep in the run. The
    # experiment is set up for N <= M (M = int(N/a)), and Estimator divides by
    # SNR and sqrt(SNR).
    if not 0 < a <= 1:
        p.error('-a must satisfy 0 < a <= 1')
    if SNR <= 0:
        p.error('-s must be > 0')
    if N < 1 or Ex < 1:
        p.error('-N and --trials must be >= 1')
    if args.seed is not None:
        np.random.seed(args.seed)

    M = int(N/a)

    # Stieltjes transform of the spectrum of X (semicircle on [c-2, c+2]).
    # It is the same for every trial, so build it once.
    gX = (z - c - sqrt(z-c-2)*sqrt(z-c+2))/2

    E_X_oracle = np.zeros(Ex)
    E_X_RIE = np.zeros(Ex)
    E_Y_oracle = np.zeros(Ex)
    E_Y_RIE = np.zeros(Ex)
    E_T_oracle = np.zeros(Ex)
    E_T_RIE = np.zeros(Ex)
    E_XY_RIE = np.zeros(Ex)

    for i in range(Ex):
        # X: symmetric Gaussian matrix (off-diagonal variance 1, diagonal
        # variance 2) scaled by 1/sqrt(N), then shifted by c*I.
        X = np.triu(np.random.normal(0, 1, (N, N)), 1)
        X = X + X.T + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=N))
        X = X/np.sqrt(N) + c*np.eye(N)

        Y = np.random.randn(N, M)/np.sqrt(N)
        W = np.random.randn(N, M)/np.sqrt(N)
        T = X @ Y

        ### Observation
        S = np.sqrt(SNR)*T + W

        ### SVD. Only the first N right singular vectors are ever used, so the
        # thin SVD suffices and avoids an M x M Vh plus zero padding.
        U_s, S_s, Vh_s = LA.svd(S, full_matrices=False)

        gS = g_CauchyK_num(S_s)

        ### Oracle estimators: u_k^T X u_k and u_k^T {Y, T} v_k for all k at once.
        e_hat_X_oracle = np.sum(U_s*(X @ U_s), axis=0)
        s_hat_Y_oracle = np.sum(U_s*(Y @ Vh_s.T), axis=0)
        s_hat_T_oracle = np.sum(U_s*(T @ Vh_s.T), axis=0)

        # (U * d) @ V equals U @ diag(d) @ V without forming diag(d).
        E_X_oracle[i] = _relative_error(X, (U_s*e_hat_X_oracle) @ U_s.T)
        E_Y_oracle[i] = _relative_error(Y, (U_s*s_hat_Y_oracle) @ Vh_s)
        E_T_oracle[i] = _relative_error(T, (U_s*s_hat_T_oracle) @ Vh_s)

        ### RIE estimators for X, Y, T
        e_hat_X, s_hat_Y, s_hat_T = Estimator(S_s, gX, gS, SNR, a, c)
        X_hat = (U_s*e_hat_X) @ U_s.T
        Y_hat = (U_s*s_hat_Y) @ Vh_s
        T_hat = (U_s*s_hat_T) @ Vh_s
        E_X_RIE[i] = _relative_error(X, X_hat)
        E_Y_RIE[i] = _relative_error(Y, Y_hat)
        E_T_RIE[i] = _relative_error(T, T_hat)
        E_XY_RIE[i] = _relative_error(T, X_hat @ Y_hat)

        # Progress log: append the cumulative error arrays after every trial.
        # Entries for trials not yet run are still 0.
        with open('SNR='+str(SNR)+'.txt', 'a') as f:
            f.write(str(E_X_oracle)+'\n')
            f.write(str(E_X_RIE)+'\n\n')
            f.write(str(E_Y_oracle)+'\n')
            f.write(str(E_Y_RIE)+'\n\n')
            f.write(str(E_T_oracle)+'\n')
            f.write(str(E_T_RIE)+'\n')
            f.write(str(E_XY_RIE)+'\n\n\n\n\n\n')

    for values, matrix_name, method in (
            (E_X_oracle, 'X', 'Oracle'),
            (E_X_RIE, 'X', 'RIE'),
            (E_Y_oracle, 'Y', 'Oracle'),
            (E_Y_RIE, 'Y', 'RIE'),
            (E_T_oracle, 'T', 'Oracle'),
            (E_T_RIE, 'T', 'RIE'),
            (E_XY_RIE, 'XY', 'RIE'),
    ):
        _save_and_print(values, matrix_name, c, SNR, method)
    

    

# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()

